# This script takes a simple astrophysical neutrino flux and propagates it through the Earth using nuSQuIDS.
# It saves the propagated fluxes to a .npz file so they can be interpolated later, and outputs a plot for verification. 
# adapted by A. Wen from the nuSQuIDS demo file.

# usage: python astro_flux_propagation.py

# imports 
################################################################################################################

# nuSQuIDS: https://github.com/arguelles/nuSQuIDS
import nuSQuIDS as nsq 
# nuflux: https://docs.icecube.aq/nuflux/main/index.html
import nuflux

try:
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    # mpl.rc('font', family='serif', size=20)
    # if hasattr(plt, "style"):
    #     plt.style.use('./paper.mplstyle')
except ImportError:
    print("matplotlib not found, disabling plotting")
    from unittest.mock import MagicMock
    plt = MagicMock()

import numpy as np


# define flux here
################################################################################################################
# astro flux function
# label for this flux model (not referenced elsewhere in this script)
flux_type='astro'
def astro_flux(E_true, norm=0.787e-18, gamma=2.5, E_pivot=100000):
    """Single power-law astrophysical neutrino flux.

    phi(E) = norm * (E_true / E_pivot) ** (-gamma)

    Parameters
    ----------
    E_true : float or numpy.ndarray
        True neutrino energy in GeV; numpy arrays are evaluated elementwise.
    norm : float
        Flux at E_pivot, in 1/(GeV s cm^2 sr).
    gamma : float
        Spectral index (the flux falls as E**-gamma).
    E_pivot : float
        Pivot energy in GeV (default 100 TeV).

    Returns
    -------
    float or numpy.ndarray
        Flux in the same units as ``norm``.

    Normalization convention: this script assigns the returned value to EACH
    flavor (e, mu, tau) and EACH charge (nu and nubar separately). That is a
    1:1:1 flavor ratio and a 1:1 nu:nubar ratio, with this value per component.
    NOTE(review): an earlier comment also called this the "total" flux. If
    ``norm`` is meant as an all-flavor nu+nubar total, each of the six
    components should get astro_flux(E)/6 instead. Confirm against the source
    of the 0.787e-18 normalization.
    """
    return norm * (E_true / E_pivot) ** (-gamma)

print('Making astro fluxes')

# nusquids setup
################################################################################################################
# nuSQuIDS unit constants; multiplying by units.GeV converts a value in GeV
# into nuSQuIDS' internal energy units.
units = nsq.Const()

# passed as the last argument of nuSQUIDSAtm to switch on neutrino interactions
# during propagation
interactions = True

# energy grid: E_nodes log-spaced nodes from 50 GeV to 5e6 GeV
E_min = 50.0*units.GeV
E_max = 5000000*units.GeV
E_nodes = 150
energy_nodes = nsq.logspace(E_min,E_max,E_nodes)

# cos(zenith) grid: cth_n_nodes linearly spaced nodes from -1.0 to 0.1.
# The node count has its own name so it is not confused with the grid array
# cth_nodes (mirrors E_nodes / energy_nodes above).
cth_min = -1.0
cth_max = 0.1
cth_n_nodes = 100
cth_nodes = nsq.linspace(cth_min,cth_max,cth_n_nodes)

neutrino_flavors = 3

# atmospheric-style propagator over the (cos-zenith, energy) grid, for
# neutrinos and antineutrinos (NeutrinoType.both)
nsq_astro = nsq.nuSQUIDSAtm(cth_nodes,energy_nodes,neutrino_flavors,nsq.NeutrinoType.both,interactions)

# populate the initial state for the nuSQUIDSAtm object with the astro flux for both charges and all 3 flavors.
################################################################################################################
# Array layout: [cos-zenith node, energy node, charge (0 = nu, 1 = nubar),
# flavor (0 = e, 1 = mu, 2 = tau)]. The astro flux depends only on energy, so it
# is evaluated once per energy node. It is then broadcast over every zenith
# node, both charges and the three flavors.
AtmInitialFlux = np.zeros((len(cth_nodes),len(energy_nodes),2,neutrino_flavors))
flux_per_energy = np.array([astro_flux(energy / 1e9) for energy in nsq_astro.GetERange()])
AtmInitialFlux[:, :, :, :3] = flux_per_energy[np.newaxis, :, np.newaxis, np.newaxis]

# nusquids propagation
################################################################################################################   
# for the astro fluxes, we do not consider oscillations
nsq_astro.Set_IncludeOscillations(False)

# load the initial-state array built above, expressed in the flavor basis
nsq_astro.Set_initial_state(AtmInitialFlux,nsq.Basis.flavor)

nsq_astro.Set_ProgressBar(True) # progress bar will be printed on terminal

# integrator relative/absolute tolerances
# NOTE(review): 1e-15 is very tight and may make EvolveState slow; confirm it is needed
nsq_astro.Set_rel_error(1.0e-15)
nsq_astro.Set_abs_error(1.0e-15)

# run the propagation over the configured (cos-zenith, energy) grid
nsq_astro.EvolveState()

# grids used by the propagator (energies in nuSQuIDS units; divide by 1e9 for GeV).
# erange is used by the plot below; cthrange is not referenced elsewhere in this script.
erange = nsq_astro.GetERange()
cthrange = nsq_astro.GetCosthRange()

# make a plot to show the effect of the propagation
################################################################################################################
# Flux at the detector after propagation, evaluated at cos(zenith) = -0.5.
# EvalFlavor(flavor, cos(zenith), energy, charge): flavor 1 is nu_mu;
# charge 0 is nu and 1 is nubar (same indexing as the initial-state array).
mu_flavor = 1
plot_cth = -0.5
phi_mubar = [nsq_astro.EvalFlavor(mu_flavor,plot_cth,EE,1) for EE in erange]
phi_mu = [nsq_astro.EvalFlavor(mu_flavor,plot_cth,EE,0) for EE in erange]

# Reference flux at the surface. This must equal what was injected into the
# initial state, which is astro_flux(E) for every flavor and every charge. So no
# extra 1/6 factor is applied here; such a factor would put the surface curves
# a factor 6 below the detector curves even where Earth absorption is negligible.
# NOTE(review): if astro_flux is meant as an all-flavor total, divide by 6 at
# injection instead, so that the saved .npz file is normalized consistently too.
erange_GeV = np.asarray(erange) / 1e9
phi_surface = astro_flux(erange_GeV)
phi_mu_surface = phi_surface
phi_mubar_surface = phi_surface

plt.figure(figsize = (6,6))

plt.plot(erange_GeV, phi_mu, lw = 2.5, color = "cornflowerblue", label = r"$\nu_\mu$ flux at detector")
plt.plot(erange_GeV, phi_mubar, lw = 2.5, color = "darkblue", label = r"$\overline{\nu}_\mu$ flux at detector")

plt.plot(erange_GeV, phi_mu_surface, lw = 1.5, color = "cornflowerblue", label = r"$\nu_\mu$ flux at surface", linestyle='--')

plt.plot(erange_GeV, phi_mubar_surface, lw = 1.5, color = "navy", label = r"$\overline{\nu}_\mu$ flux at surface", linestyle='-.')

plt.loglog()

plt.xlim(erange_GeV[0],erange_GeV[-1])
plt.xlabel(r"$E_\nu \: [{\rm GeV}]$")
plt.ylabel(r"$\phi^{astro}_\nu (E_\nu, \cos(\theta)=-0.5) \: [1/({\rm GeV}\:{\rm s}\:{\rm cm}^2\:{\rm sr})]$")

plt.grid()

plt.legend()

plt.savefig('astro.pdf')

# Save the propagated fluxes as a function of energy and cos(zenith) so they can be interpolated and evaluated later.
# (Only the muon flavor is saved.)
################################################################################################################################
# Output layout: flux[charge, cos-zenith node, energy node], where charge 0 is
# nu_mu and 1 is nu_mu-bar. The energy grid is stored in GeV.
MU_FLAVOR = 1  # flavor index of nu_mu (0 = e, 1 = mu, 2 = tau)
out_file = 'AstroMuFinalFlux.npz'

cth_grid = np.asarray(cth_nodes)
energy_grid_GeV = np.asarray(energy_nodes) / 1e9

AstroMuFinalFlux = np.zeros((2, len(cth_nodes), len(energy_nodes)))
for ic, cth in enumerate(nsq_astro.GetCosthRange()):
    for ie, E in enumerate(nsq_astro.GetERange()):
        # evaluate this node for neutrinos (0) first, then antineutrinos (1)
        AstroMuFinalFlux[:, ic, ie] = [nsq_astro.EvalFlavor(MU_FLAVOR, cth, E, rho) for rho in (0, 1)]

np.savez(out_file, cth_nodes=cth_grid, energy_nodes=energy_grid_GeV, flux=AstroMuFinalFlux)
print(f'Propagated fluxes saved to {out_file}')